package org.rocket.spring.security.browser.authentication;

import org.apache.commons.lang3.StringUtils;
import org.rocket.spring.security.core.validate.code.ImageValidateCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.Objects;

/**
 * Servlet filter that checks the image captcha submitted with requests to a single configured URI.
 *
 * <p>Requests whose URI does not equal {@code filterRequestUri} pass through untouched. For a
 * matching request, the filter checks several things:
 * <ul>
 *   <li>the {@link ImageValidateCode} stored in the HTTP session under {@code "IMAGE_VALIDATE_CODE"};</li>
 *   <li>the {@code imageCode} request parameter, compared case-insensitively;</li>
 *   <li>the code's expiry time.</li>
 * </ul>
 * Any failure is reported to the configured {@link AuthenticationFailureHandler}, and the rest of
 * the filter chain is skipped.
 *
 * @author he peng
 * @create 2018/1/25 11:58
 * @see
 */
public class ImageValidateCodeAuthenticationFilter extends OncePerRequestFilter {

    private static final Logger LOG = LoggerFactory.getLogger(ImageValidateCodeAuthenticationFilter.class);

    /** Session attribute under which the issued {@link ImageValidateCode} is stored. */
    private static final String IMAGE_VALIDATE_CODE_SESSION_KEY = "IMAGE_VALIDATE_CODE";

    /** Request parameter carrying the captcha text typed by the user. */
    private static final String IMAGE_CODE_PARAMETER = "imageCode";

    private final String filterRequestUri;

    private final AuthenticationFailureHandler authenticationFailureHandler;

    /**
     * Creates the filter.
     *
     * @param filterRequestUri the exact request URI (as returned by
     *        {@link HttpServletRequest#getRequestURI()}) whose requests must carry a valid captcha
     * @param authenticationFailureHandler handler invoked when captcha validation fails
     * @throws NullPointerException if either argument is {@code null}; a {@code null} URI would
     *         otherwise silently disable captcha checking
     */
    public ImageValidateCodeAuthenticationFilter(String filterRequestUri, AuthenticationFailureHandler authenticationFailureHandler) {
        this.filterRequestUri = Objects.requireNonNull(filterRequestUri, "filterRequestUri");
        this.authenticationFailureHandler =
                Objects.requireNonNull(authenticationFailureHandler, "authenticationFailureHandler");
    }

    /**
     * Validates the captcha for requests to {@code filterRequestUri} and passes all other requests
     * straight down the chain.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

        if (!StringUtils.equals(request.getRequestURI(), filterRequestUri)) {
            filterChain.doFilter(request, response);
            return;
        }

        try {
            validateCode(request);
        } catch (ValidateCodeAuthenticationException e) {
            authenticationFailureHandler.onAuthenticationFailure(request, response, e);
            return;
        }
        // Continue the chain outside the try so that exceptions thrown by downstream filters are
        // not misreported as captcha failures.
        filterChain.doFilter(request, response);
    }

    /**
     * Checks the user-supplied captcha against the one stored in the session.
     *
     * <p>The stored code is removed from the session when it is found expired or when it matches,
     * so a code cannot be reused after a successful check.
     *
     * @param request the current request
     * @throws ValidateCodeAuthenticationException if no code was issued, the user supplied none,
     *         the code has expired, or the supplied code does not match
     */
    private void validateCode(HttpServletRequest request) {
        LOG.info("校验图形验证码");
        // getSession(false): do not create a new session just to discover there is no captcha in it.
        HttpSession session = request.getSession(false);
        ImageValidateCode imageValidateCode = (session == null)
                ? null
                : (ImageValidateCode) session.getAttribute(IMAGE_VALIDATE_CODE_SESSION_KEY);
        if (Objects.isNull(imageValidateCode)) {
            throw new ValidateCodeAuthenticationException("没有验证码");
        }

        String userInputImageCode = request.getParameter(IMAGE_CODE_PARAMETER);
        if (StringUtils.isBlank(userInputImageCode)) {
            throw new ValidateCodeAuthenticationException("用户没有输入验证码");
        }

        LocalDateTime expireTime = imageValidateCode.getExpireTime();
        if (expireTime.isBefore(LocalDateTime.now())) {
            session.removeAttribute(IMAGE_VALIDATE_CODE_SESSION_KEY);
            throw new ValidateCodeAuthenticationException("验证码超时失效");
        }

        String validateCode = imageValidateCode.getCode();
        // Full (case-insensitive) equality: a suffix match would accept any input ending with the code.
        if (!StringUtils.equalsIgnoreCase(userInputImageCode, validateCode)) {
            throw new ValidateCodeAuthenticationException("验证码输入错误");
        }

        session.removeAttribute(IMAGE_VALIDATE_CODE_SESSION_KEY);
    }
}
